[Tests] Check WebGPU volatile allreduce annotation structurally#19740
Conversation
There was a problem hiding this comment.
Code Review
This pull request replaces a fragile string-based assertion in the thread all-reduce lower transform tests with a helper function, _has_volatile_alloc_buffer, which programmatically inspects the AST of a TVM module for volatile buffer allocations. The review feedback correctly points out that using identity comparison (is True) on TVM annotations can fail because TVM map lookups return TVM object wrappers rather than Python's built-in True singleton, and suggests using bool(...) instead.
Important
The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.
| def visit(node): | ||
| nonlocal has_volatile_alloc | ||
| if isinstance(node, tvm.tirx.AllocBuffer) and "tirx.volatile" in node.annotations: | ||
| has_volatile_alloc = has_volatile_alloc or node.annotations["tirx.volatile"] is True |
There was a problem hiding this comment.
Using is True to check the value of a TVM annotation can lead to failures. TVM map lookups typically return TVM object wrappers (such as tvm.tir.IntImm or tvm.runtime.Bool) rather than Python's built-in True singleton, so identity comparison (is) will evaluate to False. Using bool(...) is more robust and correctly evaluates the truthiness of the TVM object.
| has_volatile_alloc = has_volatile_alloc or node.annotations["tirx.volatile"] is True | |
| has_volatile_alloc = has_volatile_alloc or bool(node.annotations["tirx.volatile"]) |
…he#19740) This pr updates the WebGPU multi-warp allreduce test to check the generated `tirx.volatile` allocation annotation structurally instead of matching the exact TVMScript printer output. The test is intended to verify that `LowerThreadAllreduce` marks the generated shared allocation as volatile. It previously checked for the exact string: ```python "tirx.volatile": T.bool(True) ``` However, the current printer emits the same annotation as: ```python annotations={"tirx.volatile": True} ``` The transform behavior is unchanged; only the printer spelling differs. This patch walks the generated TIRX body and checks for an `AllocBuffer` with `tirx.volatile=True`, which matches the actual semantic requirement of the test without depending on bool literal formatting. (cherry picked from commit 30bf568)
…he#19740) This pr updates the WebGPU multi-warp allreduce test to check the generated `tirx.volatile` allocation annotation structurally instead of matching the exact TVMScript printer output. The test is intended to verify that `LowerThreadAllreduce` marks the generated shared allocation as volatile. It previously checked for the exact string: ```python "tirx.volatile": T.bool(True) ``` However, the current printer emits the same annotation as: ```python annotations={"tirx.volatile": True} ``` The transform behavior is unchanged; only the printer spelling differs. This patch walks the generated TIRX body and checks for an `AllocBuffer` with `tirx.volatile=True`, which matches the actual semantic requirement of the test without depending on bool literal formatting. (cherry picked from commit 30bf568)
This pr updates the WebGPU multi-warp allreduce test to check the generated
tirx.volatileallocation annotation structurally instead of matching the exact TVMScript printer output.The test is intended to verify that
LowerThreadAllreducemarks the generated shared allocation as volatile. It previously checked for the exact string:However, the current printer emits the same annotation as:
The transform behavior is unchanged; only the printer spelling differs. This patch walks the generated TIRX body and checks for an
AllocBufferwithtirx.volatile=True, which matches the actual semantic requirement of the test without depending on bool literal formatting.